"""Instantiate the main core lattice."""

import numpy as np

import openmc

from nuscale.materials import mats
# from .reflector import reflector_universes
from nuscale.assemblies import assembly_universes
from nuscale import surfaces


def core_geometry(control_rods='all_rods_out'):
    """Generate full core SMR geometry.
    Parameters
    -------
    control_rods : str
        all_rods_out - all groups of rods withdrawn (default)
        RE1 - RE1 group is inserted
        RE2 - RE2 group is inserted
        SH3 - SH3 group is inserted
        SH4 - SH4 group is inserted
        single_rod - a single rod is inserted as specified in benchmark's description
        all_rods_in - all groups of rods inserted

    Returns
    -------
    openmc.Geometry
        SMR full core geometry

    """
    assembly = assembly_universes()
    # reflector = reflector_universes()

    # Construct main core lattice
    core = openmc.RectLattice(name='Main core')
    lattice_pitch = surfaces.assembly_pitch
    core.lower_left = (-9*lattice_pitch/2, -9*lattice_pitch/2)
    core.pitch = (lattice_pitch, lattice_pitch)

    # Reflector cell
    cell = openmc.Cell(name='Reflector cell', fill=mats['Heavy reflector'])
    reflector = openmc.Universe(name='Reflector universe')
    reflector.add_cell(cell)

    universes = np.empty(shape=(9,9), dtype=openmc.Universe)
    universes[:,:] = reflector

    A01_y = np.array([2, 3, 4, 4, 4, 4, 5, 6])
    A01_x = np.array([4, 4, 2, 3, 5, 6, 4, 4])

    A02_y = np.array([3, 3, 5, 5])
    A02_x = np.array([3, 5, 3, 5])

    B01_y = np.array([2, 2, 3, 3, 5, 5, 6, 6])
    B01_x = np.array([3, 5, 2, 6, 2, 6, 3, 5])

    B02_y = np.array([1, 4, 4, 7])
    B02_x = np.array([4, 1, 7, 4])

    C01_y = np.array([1, 1, 3, 3, 5, 5, 7, 7])
    C01_x = np.array([3, 5, 1, 7, 1, 7, 3, 5])

    C02_y = np.array([2, 2, 6, 6])
    C02_x = np.array([2, 6, 2, 6])

    C03_y = np.array([4])
    C03_x = np.array([4])

    RE1_y = np.array([3, 4, 4, 5])
    RE1_x = np.array([4, 3, 5, 4])

    RE2_y = np.array([1, 4, 4, 7])
    RE2_x = np.array([4, 1, 7, 4])

    SH3_y = np.array([2, 3, 5, 6])
    SH3_x = np.array([3, 6, 2, 5])

    SH4_y = np.array([2, 3, 5, 6])
    SH4_x = np.array([5, 2, 6, 3])

    SR_y = np.array([3])
    SR_x = np.array([4])

    universes[A01_x, A01_y] = assembly['A01 Assembly no CRs']
    universes[A02_x, A02_y] = assembly['A02 Assembly no CRs']
    universes[B01_x, B01_y] = assembly['B01 Assembly no CRs']
    universes[B02_x, B02_y] = assembly['B02 Assembly no CRs']
    universes[C01_x, C01_y] = assembly['C01 Assembly no CRs']
    universes[C02_x, C02_y] = assembly['C02 Assembly no CRs']
    universes[C03_x, C03_y] = assembly['C03 Assembly no CRs']

    if control_rods == 'all_rods_out': # All rods out
        pass
    elif control_rods == 'RE1': # RE1 group inserted
        universes[RE1_x, RE1_y] = assembly['A01 Assembly with CRs']
    elif control_rods == 'RE2':
        universes[RE2_x, RE2_y] = assembly['B02 Assembly with CRs']
    elif control_rods == 'SH3':
        universes[SH3_x, SH3_y] = assembly['B01 Assembly with CRs']
    elif control_rods == 'SH4':
        universes[SH4_x, SH4_y] = assembly['B01 Assembly with CRs']
    elif control_rods == 'single_rod':
        universes[SR_x, SR_y] = assembly['A01 Assembly with CRs']
    elif control_rods == 'all_rods_in':
        universes[RE1_x, RE1_y] = assembly['A01 Assembly with CRs']
        universes[RE2_x, RE2_y] = assembly['B02 Assembly with CRs']
        universes[SH3_x, SH3_y] = assembly['B01 Assembly with CRs']
        universes[SH4_x, SH4_y] = assembly['B01 Assembly with CRs']

    core.universes = universes

    root_univ = openmc.Universe(universe_id=0, name='root universe')
    surfs = surfaces.surfs

    # Cylinder filled with core lattice
    cell = openmc.Cell(name='Main core')
    cell.fill = core
    cell.region = \
        -surfs['core barrel IR'] & +surfs['lower bound'] & -surfs['upper bound']
    root_univ.add_cell(cell)

    # Core barrel
    cell = openmc.Cell(name='core barrel')
    cell.fill = mats['SS']
    cell.region = (+surfs['core barrel IR'] & -surfs['core barrel OR'] &
                   +surfs['lower bound'] & -surfs['upper bound'])
    root_univ.add_cell(cell)

    # Return geometry
    return openmc.Geometry(root_univ)